#!/usr/bin/env python

#####################################################################
#
# pfcstat is a tool for summarizing Priority-based Flow Control (PFC) statistics. 
#
#####################################################################

import swsssdk
import sys
import argparse
import cPickle as pickle
import datetime
import os.path

from collections import namedtuple, OrderedDict
from natsort import natsorted
from tabulate import tabulate


# Per-port PFC counters: one field per priority, 0 through 7.
PStats = namedtuple("PStats", ("pfc0", "pfc1", "pfc2", "pfc3", "pfc4", "pfc5", "pfc6", "pfc7"))

# Column titles of the printed tables; the first column holds the port name.
_PFC_COLUMNS = ['PFC0', 'PFC1', 'PFC2', 'PFC3', 'PFC4', 'PFC5', 'PFC6', 'PFC7']
header_Rx = ['Port Rx'] + _PFC_COLUMNS
header_Tx = ['Port Tx'] + _PFC_COLUMNS

# Map every SAI counter name (SAI_PORT_STAT_PFC_<priority>_RX_PKTS /
# SAI_PORT_STAT_PFC_<priority>_TX_PKTS) to the index of the PStats field it
# fills.  Generator expressions are used so that no loop variable is left
# behind at module level.
counter_bucket_rx_dict = dict(
    ('SAI_PORT_STAT_PFC_%d_RX_PKTS' % priority, priority) for priority in range(8)
)
counter_bucket_tx_dict = dict(
    ('SAI_PORT_STAT_PFC_%d_TX_PKTS' % priority, priority) for priority in range(8)
)

# Placeholder shown when a counter is absent from the database.
STATUS_NA = 'N/A'

# Key prefix of the per-port counter tables and the name of the hash that
# maps port names to the table ids appended to that prefix.
COUNTER_TABLE_PREFIX = "COUNTERS:"
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"

class Pfcstat(object):
    """
        Reads the per-port PFC counters from COUNTERS_DB and prints them as
        tables, either as absolute values or as a difference against an
        earlier snapshot.
    """
    def __init__(self):
        """
            Connect to COUNTERS_DB on 127.0.0.1.
        """
        self.db = swsssdk.SonicV2Connector(host='127.0.0.1')
        self.db.connect(self.db.COUNTERS_DB)

    def get_cnstat(self, rx):
        """
            Take a snapshot of the PFC counters of every port.

            rx selects the direction: a true value reads the RX counters,
            a false value the TX counters.

            Returns an OrderedDict whose first entry, 'time', holds the
            moment of the snapshot; it is followed by one PStats per port,
            in natsorted() order of the port names.  Only the 'time' entry
            is present when the port name map is missing.
        """
        bucket_dict = counter_bucket_rx_dict if rx else counter_bucket_tx_dict

        def get_counters(table_id):
            """
                Read the eight per-priority counters of one port table.
            """
            full_table_id = COUNTER_TABLE_PREFIX + table_id
            fields = ["0"] * 8
            for counter_name, pos in bucket_dict.items():
                counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, counter_name)
                # A counter that is absent from the database is reported as N/A.
                fields[pos] = STATUS_NA if counter_data is None else str(int(counter_data))
            return PStats._make(fields)

        # Port name -> table id, as stored in the database
        counter_port_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_PORT_NAME_MAP)

        # The timestamp always comes first; ports follow when there are any
        cnstat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()
        if counter_port_name_map is not None:
            for port in natsorted(counter_port_name_map):
                cnstat_dict[port] = get_counters(counter_port_name_map[port])
        return cnstat_dict

    def cnstat_print(self, cnstat_dict, rx):
        """
            Print a snapshot as a table, one row per port.

            rx only selects the table header (RX when true, TX otherwise).
        """
        # Every entry except the 'time' stamp becomes a row: port, pfc0..pfc7.
        rows = [
            (port,) + tuple(getattr(stats, field) for field in PStats._fields)
            for port, stats in cnstat_dict.items()
            if port != 'time'
        ]
        print(tabulate(rows, header_Rx if rx else header_Tx, tablefmt='simple', stralign='right'))

    def cnstat_diff_print(self, cnstat_new_dict, cnstat_old_dict, rx):
        """
            Print how far each counter of cnstat_new_dict has moved since
            cnstat_old_dict was taken, one row per port.

            Ports without an entry in cnstat_old_dict are printed with
            their current values.  rx only selects the table header.
        """
        def ns_diff(newstr, oldstr):
            """
                Subtract two counter strings; N/A on either side gives N/A.
            """
            if STATUS_NA in (newstr, oldstr):
                return STATUS_NA
            return '{:,}'.format(int(newstr) - int(oldstr))

        rows = []
        for port, cntr in cnstat_new_dict.items():
            if port == 'time':
                continue

            old_cntr = cnstat_old_dict.get(port)
            if old_cntr is None:
                # No baseline for this port: show the counters unchanged.
                values = [getattr(cntr, field) for field in PStats._fields]
            else:
                values = [ns_diff(getattr(cntr, field), getattr(old_cntr, field))
                          for field in PStats._fields]
            rows.append((port,) + tuple(values))

        print(tabulate(rows, header_Rx if rx else header_Tx, tablefmt='simple', stralign='right'))

def _print_counters(pfcstat, cnstat_dict, cnstat_fqn_file, rx):
    """
        Print one direction (RX when rx is true, TX otherwise) of the PFC
        counters.

        When a saved baseline exists at cnstat_fqn_file the difference
        against it is printed, otherwise the current counters are printed
        unchanged.  A baseline that cannot be read is reported and skipped.
    """
    if not os.path.isfile(cnstat_fqn_file):
        pfcstat.cnstat_print(cnstat_dict, rx)
        return

    try:
        # NOTE(review): the baseline is unpickled from a predictable path
        # under /tmp, and pickle can run arbitrary code while loading.
        # Confirm that the directory cannot be pre-created by another user.
        with open(cnstat_fqn_file, 'rb') as cached_file:
            cnstat_cached_dict = pickle.load(cached_file)
    except IOError as e:
        print("%s %s" % (e.errno, e))
        return

    print("Last cached time was " + str(cnstat_cached_dict.get('time')))
    pfcstat.cnstat_diff_print(cnstat_dict, cnstat_cached_dict, rx)


def main():
    """
        Entry point of the pfcstat command.

        Without options, print the RX and then the TX PFC counters of every
        port, as a difference against the saved baseline when one exists.
        -c/--clear saves the current counters as the new baseline and
        -d/--delete removes the saved baseline.  Always ends through
        sys.exit(): 0 on success, non-zero when the baseline could not be
        deleted, created or written.
    """
    parser  = argparse.ArgumentParser(description='Display the pfc counters',
                                      version='1.0.0',
                                      formatter_class=argparse.RawTextHelpFormatter,
                                      epilog="""
Examples:
  pfcstat
  pfcstat -c
  pfcstat -d
""")

    parser.add_argument('-c', '--clear', action='store_true', help='Clear previous stats and save new ones')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats')
    args = parser.parse_args()

    save_fresh_stats = args.clear
    delete_all_stats = args.delete

    # The baseline lives in a per-user directory: /tmp/pfcstat-<uid>/<uid>{rx,tx}
    uid = str(os.getuid())
    cnstat_file = uid

    cnstat_dir = "/tmp/pfcstat-" + uid
    cnstat_fqn_file_rx = cnstat_dir + "/" + cnstat_file + "rx"
    cnstat_fqn_file_tx = cnstat_dir + "/" + cnstat_file + "tx"

    if delete_all_stats:
        # Deleting needs no database access, so it is handled before the
        # connection is opened.  Nothing saved means nothing to delete.
        if os.path.isdir(cnstat_dir):
            try:
                for cached_name in os.listdir(cnstat_dir):
                    os.remove(cnstat_dir + "/" + cached_name)
                os.rmdir(cnstat_dir)
            except (IOError, OSError) as e:
                # The os.* calls raise OSError, which "except IOError" alone
                # does not catch on Python 2.
                print("%s %s" % (e.errno, e))
                sys.exit(1)
        sys.exit(0)

    pfcstat = Pfcstat()

    # Get the counters of pfc rx counter
    cnstat_dict_rx = pfcstat.get_cnstat(True)

    # Get the counters of pfc tx counter
    cnstat_dict_tx = pfcstat.get_cnstat(False)

    # At this point, either we'll create a file or open an existing one.
    if not os.path.exists(cnstat_dir):
        try:
            os.makedirs(cnstat_dir)
        except (IOError, OSError) as e:
            print("%s %s" % (e.errno, e))
            sys.exit(1)

    if save_fresh_stats:
        try:
            # "with" closes (and therefore flushes) each file before exiting.
            with open(cnstat_fqn_file_rx, 'wb') as rx_file:
                pickle.dump(cnstat_dict_rx, rx_file)
            with open(cnstat_fqn_file_tx, 'wb') as tx_file:
                pickle.dump(cnstat_dict_tx, tx_file)
        except IOError as e:
            print("%s %s" % (e.errno, e))
            sys.exit(e.errno)
        print("Clear saved counters")
        sys.exit(0)

    # Print the counters of pfc rx counter
    _print_counters(pfcstat, cnstat_dict_rx, cnstat_fqn_file_rx, True)

    print("")

    # Print the counters of pfc tx counter
    _print_counters(pfcstat, cnstat_dict_tx, cnstat_fqn_file_tx, False)

    sys.exit(0)

# Run the command only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
